package com.micro.platform.gateway.filter.encryption;


import com.alibaba.fastjson.JSON;
import com.micro.platform.starter.utils.RSACoderUtil;
import com.micro.platform.starter.utils.StringUtils;
import io.netty.buffer.ByteBufAllocator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.factory.rewrite.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.DefaultServerRequest;
import org.springframework.core.Ordered;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

/**
 * Global gateway filter that decrypts the body of incoming POST, PUT and PATCH requests.
 *
 * <p>Active only when {@code encryption.switch} is {@code true} and a non-empty
 * {@code encryption.private-key} is configured. Two content types are rewritten:
 * <ul>
 *   <li>JSON ({@code application/json}, any parameters): the body must be an object with a
 *       string field {@code encryptData}; the whole body is replaced by its decrypted
 *       plaintext.</li>
 *   <li>URL-encoded forms: every field value is decrypted individually and the form is
 *       re-serialized using the request's charset (UTF-8 if none is declared).</li>
 * </ul>
 * All other requests are forwarded unchanged. Decryption is delegated to
 * {@code RSACoderUtil.decryptByPrivateKey} and the resulting bytes are decoded as UTF-8.
 *
 * @date 2020/12/4 17:34
 * @author wei.heng
 */
@Component
public class GlobalRequestBodyDecodeFilter implements GlobalFilter, Ordered {

    private static final Logger log = LoggerFactory.getLogger(GlobalRequestBodyDecodeFilter.class);

    /** Master switch for request-body decryption. */
    @Value("${encryption.switch:false}")
    private boolean encryptionSwitch;

    /**
     * Key passed to {@code RSACoderUtil.decryptByPrivateKey}. Defaults to the empty string so
     * that a missing property disables decryption instead of supplying the literal text
     * {@code "null"} as a key.
     */
    @Value("${encryption.private-key:}")
    private String privateKey;

    /**
     * Routes body-carrying requests (POST/PUT/PATCH) through decryption when the feature is
     * enabled and a key is configured; every other request is passed through untouched.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        if (encryptionSwitch && StringUtils.isNotEmpty(privateKey)) {
            HttpMethod method = exchange.getRequest().getMethod();
            if (method == HttpMethod.POST || method == HttpMethod.PUT || method == HttpMethod.PATCH) {
                return operationExchange(exchange, chain);
            }
        }
        return chain.filter(exchange);
    }

    /** Near-highest precedence (lower values run first) so later filters see the plaintext body. */
    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 1000;
    }

    /**
     * Dispatches on the request Content-Type. Only type/subtype are compared, so parameters such
     * as {@code charset} do not prevent a match (e.g. browsers commonly send
     * {@code application/x-www-form-urlencoded; charset=UTF-8}). A missing Content-Type matches
     * nothing and the request is forwarded unchanged.
     */
    private Mono<Void> operationExchange(ServerWebExchange exchange, GatewayFilterChain chain) {
        MediaType contentType = exchange.getRequest().getHeaders().getContentType();
        if (MediaType.APPLICATION_JSON.includes(contentType)) {
            return readBody(exchange, chain);
        }
        if (MediaType.APPLICATION_FORM_URLENCODED.includes(contentType)
                || MediaType.MULTIPART_FORM_DATA.includes(contentType)) {
            return readFormData(exchange, chain);
        }
        return chain.filter(exchange);
    }

    /**
     * Decrypts each value of a form body and forwards the request with the re-serialized form.
     * An empty form is forwarded unchanged.
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, GatewayFilterChain chain) {
        final ServerHttpRequest request = exchange.getRequest();
        // Content-Type is non-null here: the caller only routes form/multipart types to this method.
        Charset declared = request.getHeaders().getContentType().getCharset();
        final Charset charset = declared == null ? StandardCharsets.UTF_8 : declared;
        // NOTE(review): getFormData() is expected to parse only URL-encoded bodies, so multipart
        // requests presumably arrive with an empty map and pass through undecrypted - confirm.
        MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
        return exchange.getFormData()
                .doOnNext(parsed -> formData.addAll(parsed))
                .then(Mono.defer(() -> {
                    if (formData.isEmpty()) {
                        return chain.filter(exchange);
                    }
                    byte[] bodyBytes = serializeDecryptedForm(formData, charset).getBytes(charset);
                    ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(request) {
                        @Override
                        public HttpHeaders getHeaders() {
                            return withBodyLength(super.getHeaders(), bodyBytes.length);
                        }

                        @Override
                        public Flux<DataBuffer> getBody() {
                            // Guard against a zero buffer size should the body ever be empty.
                            return DataBufferUtils.read(new ByteArrayResource(bodyBytes),
                                    new NettyDataBufferFactory(ByteBufAllocator.DEFAULT),
                                    Math.max(bodyBytes.length, 1));
                        }
                    };
                    return chain.filter(exchange.mutate().request(decorator).build());
                }));
    }

    /**
     * Serializes {@code formData} as an {@code application/x-www-form-urlencoded} body whose
     * values are the decrypted originals. Keys and values are URL-encoded with {@code charset}:
     * the map coming from {@code getFormData()} is expected to hold already-decoded text, and
     * writing it back raw would corrupt plaintext containing '&amp;', '=', '+', '%', spaces or
     * non-ASCII characters. A {@code null} value (a bare key, if the form reader produced one)
     * is written back as the bare key.
     */
    private String serializeDecryptedForm(MultiValueMap<String, String> formData, Charset charset) {
        StringBuilder body = new StringBuilder();
        for (Map.Entry<String, List<String>> entry : formData.entrySet()) {
            String encodedKey = urlEncode(entry.getKey(), charset);
            for (String value : entry.getValue()) {
                if (body.length() > 0) {
                    body.append('&');
                }
                body.append(encodedKey);
                if (value != null) {
                    body.append('=').append(urlEncode(decodeForm(value, privateKey), charset));
                }
            }
        }
        return body.toString();
    }

    /** URL-encodes {@code text} for a form body using {@code charset}. */
    private static String urlEncode(String text, Charset charset) {
        try {
            return URLEncoder.encode(text, charset.name());
        } catch (UnsupportedEncodingException e) {
            // Cannot happen: the name comes from an already-resolved Charset instance.
            throw new IllegalStateException(e);
        }
    }

    /**
     * Replaces a JSON body of the form {@code {"encryptData": "..."}} with its decrypted
     * plaintext and forwards the request. Empty bodies are forwarded with an empty body.
     */
    private Mono<Void> readBody(ServerWebExchange exchange, GatewayFilterChain chain) {
        final ServerHttpRequest request = exchange.getRequest();
        ServerRequest serverRequest = new DefaultServerRequest(exchange);
        // Exceptions thrown by decodeBody inside map() become an error signal on this Mono.
        Mono<String> decryptedBody = serverRequest.bodyToMono(String.class)
                .filter(body -> StringUtils.isNotEmpty(body))
                .map(body -> decodeBody(body, privateKey));
        BodyInserter bodyInserter = BodyInserters.fromPublisher(decryptedBody, String.class);
        // The original Content-Length describes the encrypted body, so it must not be reused.
        HttpHeaders headers = new HttpHeaders();
        headers.putAll(request.getHeaders());
        headers.remove(HttpHeaders.CONTENT_LENGTH);
        CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, headers);
        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                .then(Mono.defer(() -> {
                    ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(request) {
                        @Override
                        public HttpHeaders getHeaders() {
                            // -1 unless a length was recorded on `headers` while writing the body.
                            return withBodyLength(super.getHeaders(), headers.getContentLength());
                        }

                        @Override
                        public Flux<DataBuffer> getBody() {
                            return outputMessage.getBody();
                        }
                    };
                    return chain.filter(exchange.mutate().request(decorator).build());
                }));
    }

    /**
     * Returns a copy of {@code original} whose framing headers match a rewritten body: an
     * explicit Content-Length (and no Transfer-Encoding) when {@code contentLength} is positive,
     * otherwise chunked transfer encoding with any stale Content-Length removed. Sending both
     * headers at once would make the forwarded request ambiguous.
     */
    private static HttpHeaders withBodyLength(HttpHeaders original, long contentLength) {
        HttpHeaders httpHeaders = new HttpHeaders();
        httpHeaders.putAll(original);
        if (contentLength > 0) {
            httpHeaders.remove(HttpHeaders.TRANSFER_ENCODING);
            httpHeaders.setContentLength(contentLength);
        } else {
            httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
            httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
        }
        return httpHeaders;
    }

    /**
     * Extracts the {@code encryptData} string from a JSON object body and decrypts it.
     *
     * @throws ResponseStatusException 400 if the body is not a JSON object with a string
     *         {@code encryptData} field; 500 if decryption fails
     */
    private String decodeBody(String body, String privateKeyStr) {
        Object parsed;
        try {
            parsed = JSON.parse(body);
        } catch (RuntimeException e) {
            // fastjson reports malformed input with an unchecked exception.
            throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "请求体格式错误", e);
        }
        Object encryptData = parsed instanceof Map ? ((Map<?, ?>) parsed).get("encryptData") : null;
        if (!(encryptData instanceof String)) {
            throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "请求体缺少encryptData字段");
        }
        return decrypt((String) encryptData, privateKeyStr);
    }

    /**
     * Decrypts a single form field value.
     *
     * @throws ResponseStatusException 500 if decryption fails
     */
    private String decodeForm(String formStr, String privateKeyStr) {
        return decrypt(formStr, privateKeyStr);
    }

    /**
     * Decrypts {@code cipherText} with {@code privateKeyStr} and decodes the result as UTF-8.
     * Failures are logged with their stack trace and rethrown as a 500 that keeps the cause.
     */
    private String decrypt(String cipherText, String privateKeyStr) {
        try {
            byte[] decodedBytes = RSACoderUtil.decryptByPrivateKey(cipherText, privateKeyStr);
            return new String(decodedBytes, StandardCharsets.UTF_8);
        } catch (Exception e) {
            log.error("解密异常：{}", e.getMessage(), e);
            throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "请求体解密异常", e);
        }
    }
}